import torch
import torch.nn as nn


class StationGatedFusion(nn.Module):
    """Blend two feature tensors with a learned scalar gate, then split into groups.

    For every position a single gate value is computed from both inputs::

        g   = sigmoid(Linear([x_loc; x_glb2st]))      # shape (..., 1)
        res = g * x_loc + (1 - g) * x_glb2st          # convex combination

    The fused last dimension is then reshaped into ``(num_site, hidden // num_site)``.

    Args:
        hidden: Size of the last dimension of each input. It must be divisible by
            ``num_site`` for the final reshape to succeed.
        num_site: Number of equal-sized groups the fused features are split into.
            Defaults to 39, the value this module originally fixed.
    """

    def __init__(self, hidden, num_site=39):
        super().__init__()
        # Maps the concatenated inputs (2 * hidden features) to one gate logit.
        self.proj = nn.Linear(2 * hidden, 1)
        self.num_site = num_site

    def forward(self, x_loc, x_glb2st):
        """Fuse ``x_loc`` and ``x_glb2st`` and split the result into ``num_site`` groups.

        Args:
            x_loc: Tensor of shape ``(..., hidden)``.
            x_glb2st: Tensor expected to have the same shape as ``x_loc``.

        Returns:
            Tensor of shape ``(..., num_site, hidden // num_site)``. The leading
            dimensions of the inputs are preserved (previously only 3-D
            ``(b, t, hidden)`` inputs were accepted).

        Raises:
            RuntimeError: From ``reshape`` if ``hidden`` is not divisible by
                ``num_site``.
        """
        x = torch.cat([x_loc, x_glb2st], dim=-1)
        # Shape (..., 1): one gate per position, broadcast over all hidden features.
        g = torch.sigmoid(self.proj(x))
        res = g * x_loc + (1 - g) * x_glb2st
        # Keep every leading dim; split only the feature dim into groups.
        return res.reshape(*res.shape[:-1], self.num_site, -1)